# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import collections
from typing import Tuple, Union

import numpy as np
from tqdm import tqdm

# Named record holding the five parallel fields of one batch of transitions.
Batch = collections.namedtuple(
    typename="Batch",
    field_names=(
        "observations",
        "actions",
        "rewards",
        "masks",
        "next_observations",
    ),
)


def split_into_trajectories(
    observations, actions, rewards, masks, terminals_float, dones_float, next_observations
):
    """Group parallel per-transition sequences into a list of trajectories.

    Every transition becomes a 7-tuple
    ``(observation, action, reward, mask, terminal, done, next_observation)``.
    A new trajectory is opened after each transition whose ``dones_float``
    entry equals 1.0, unless that transition is the very last one. The number
    of transitions is taken from ``len(observations)``.

    Returns:
        A list of trajectories, each a list of 7-tuples. For empty input the
        result is a single empty trajectory (``[[]]``).
    """
    num_steps = len(observations)
    episodes = [[]]

    # tqdm only renders a progress bar; iteration order is plain 0..num_steps-1.
    for idx in tqdm(range(num_steps)):
        step = (
            observations[idx],
            actions[idx],
            rewards[idx],
            masks[idx],
            terminals_float[idx],
            dones_float[idx],
            next_observations[idx],
        )
        episodes[-1].append(step)

        # Do not open a trailing empty trajectory after the final transition.
        if dones_float[idx] == 1.0 and idx + 1 < num_steps:
            episodes.append([])

    return episodes


def merge_trajectories(trajs):
    """Flatten a list of trajectories back into seven stacked arrays.

    Args:
        trajs: Iterable of trajectories, each an iterable of 7-tuples
            ``(observation, action, reward, mask, terminal, done,
            next_observation)``.

    Returns:
        A 7-tuple of arrays ``(observations, actions, rewards, masks,
        terminals_float, dones_float, next_observations)``, each built with
        ``np.stack`` over all transitions in trajectory order.

    Raises:
        ValueError: If there are no transitions at all (``np.stack`` of an
            empty list) or a step does not have exactly seven elements.
    """
    # One accumulator per field, in the same order as the step tuple.
    columns = tuple([] for _ in range(7))

    for episode in trajs:
        for step in episode:
            # Explicit unpacking enforces the 7-element step layout.
            obs, act, rew, mask, terminal, done, next_obs = step
            fields = (obs, act, rew, mask, terminal, done, next_obs)
            for column, value in zip(columns, fields):
                column.append(value)

    return tuple(np.stack(column) for column in columns)


class Dataset(object):
    """Transition dataset stored as parallel arrays, with trajectory utilities.

    The seven arrays are indexed in lockstep by transition. Trajectory
    boundaries are recovered with ``split_into_trajectories``, which starts a
    new trajectory after every transition whose ``dones_float`` entry equals
    1.0. Each step of a trajectory is the 7-tuple
    ``(observation, action, reward, mask, terminal, done, next_observation)``.
    """

    def __init__(
        self,
        observations: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        masks: np.ndarray,
        terminals_float: np.ndarray,
        dones_float: np.ndarray,
        next_observations: np.ndarray,
        size: int,
    ):
        """Store the transition arrays without copying them.

        Args:
            observations: Per-transition observations.
            actions: Per-transition actions.
            rewards: Per-transition rewards.
            masks: Per-transition mask values (returned as-is by ``sample``).
            terminals_float: Per-transition terminal flags; carried through
                splits and merges but not otherwise interpreted by this class.
            dones_float: Per-transition flags; 1.0 marks the last transition
                of a trajectory.
            next_observations: Per-transition successor observations.
            size: Number of transitions ``sample`` draws indices from.
        """
        self.observations = observations
        self.actions = actions
        self.rewards = rewards
        self.masks = masks
        self.terminals_float = terminals_float
        self.dones_float = dones_float
        self.next_observations = next_observations
        self.size = size

    def _trajectories(self):
        """Split the stored transitions into a list of trajectories."""
        return split_into_trajectories(
            self.observations,
            self.actions,
            self.rewards,
            self.masks,
            self.terminals_float,
            self.dones_float,
            self.next_observations,
        )

    def _set_from_trajectories(self, trajs) -> None:
        """Replace the stored arrays (and ``size``) with the merged ``trajs``."""
        (
            self.observations,
            self.actions,
            self.rewards,
            self.masks,
            self.terminals_float,
            self.dones_float,
            self.next_observations,
        ) = merge_trajectories(trajs)
        self.size = len(self.observations)

    @staticmethod
    def _episode_return(traj):
        """Return the undiscounted sum of rewards over the steps of ``traj``."""
        episode_return = 0
        # The reward is the third element of each step; ``*_`` absorbs the
        # remaining fields so the unpacking cannot drift out of sync with the
        # step layout (a fixed 6-name unpack raised ValueError on 7-tuples).
        for _, _, rew, *_ in traj:
            episode_return += rew
        return episode_return

    @staticmethod
    def _discounted_return(traj, discount):
        """Return ``sum_i reward_i * discount**i`` over the steps of ``traj``."""
        mc_return = 0.0
        for i, (_, _, reward, *_) in enumerate(traj):
            mc_return += reward * (discount**i)
        return mc_return

    def sample(self, batch_size: int) -> "Batch":
        """Draw ``batch_size`` transitions uniformly at random with replacement.

        Indices come from ``np.random.randint(self.size, ...)``, i.e. NumPy's
        global random state.
        """
        indx = np.random.randint(self.size, size=batch_size)
        return Batch(
            observations=self.observations[indx],
            actions=self.actions[indx],
            rewards=self.rewards[indx],
            masks=self.masks[indx],
            next_observations=self.next_observations[indx],
        )

    def get_initial_states(
        self, and_action: bool = False
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Return the first observation of every trajectory.

        Trajectories are ordered by ascending undiscounted return before their
        first steps are collected.

        Args:
            and_action: If true, also return the first action of every
                trajectory, in the same order.

        Returns:
            The stacked initial observations, or a ``(states, actions)`` pair
            when ``and_action`` is true.
        """
        trajs = self._trajectories()
        trajs.sort(key=self._episode_return)

        states = np.stack([traj[0][0] for traj in trajs], 0)
        if and_action:
            actions = np.stack([traj[0][1] for traj in trajs], 0)
            return states, actions
        return states

    def get_monte_carlo_returns(self, discount) -> np.ndarray:
        """Return one discounted return per trajectory, from its first step.

        Args:
            discount: Discount factor applied as ``discount**i`` to the reward
                at step ``i`` of each trajectory.

        Returns:
            An array with one entry per trajectory, in storage order.
        """
        mc_returns = [
            self._discounted_return(traj, discount) for traj in self._trajectories()
        ]
        return np.asarray(mc_returns)

    def take_top(self, percentile: float = 100.0):
        """Keep, in place, only the highest-return ``percentile`` % of trajectories.

        At least one trajectory is always kept. The surviving trajectories are
        stored in ascending order of undiscounted return.

        Args:
            percentile: Share of trajectories to keep, in ``(0, 100]``.
        """
        assert percentile > 0.0 and percentile <= 100.0

        trajs = self._trajectories()
        trajs.sort(key=self._episode_return)

        N = max(1, int(len(trajs) * percentile / 100))
        # Sorted ascending, so the best trajectories are at the tail.
        self._set_from_trajectories(trajs[-N:])

    def take_random(self, percentage: float = 100.0):
        """Keep, in place, a random ``percentage`` % of trajectories.

        Trajectories are shuffled with ``np.random.shuffle`` and the last ``N``
        are kept, where ``N`` is at least one.

        Args:
            percentage: Share of trajectories to keep, in ``(0, 100]``.
        """
        assert percentage > 0.0 and percentage <= 100.0

        trajs = self._trajectories()
        np.random.shuffle(trajs)

        N = max(1, int(len(trajs) * percentage / 100))
        self._set_from_trajectories(trajs[-N:])

    def train_validation_split(
        self, train_fraction: float = 0.8
    ) -> Tuple["Dataset", "Dataset"]:
        """Split the trajectories into a training and a validation dataset.

        The split is deterministic: trajectories are taken in storage order
        (no shuffling), the first ``int(train_fraction * num_trajectories)``
        going to the training set and the rest to the validation set.

        Args:
            train_fraction: Fraction of trajectories assigned to training.

        Returns:
            A ``(train_dataset, valid_dataset)`` pair of new ``Dataset`` objects.

        Raises:
            ValueError: If either side of the split would contain no
                trajectories.
        """
        trajs = self._trajectories()
        train_size = int(train_fraction * len(trajs))

        train_trajs = trajs[:train_size]
        valid_trajs = trajs[train_size:]
        # Merging an empty list fails inside np.stack with an opaque message;
        # fail early with the same exception type and a useful explanation.
        if not train_trajs or not valid_trajs:
            raise ValueError(
                "train_validation_split needs at least one trajectory on each "
                f"side: got {len(train_trajs)} train and {len(valid_trajs)} "
                f"validation trajectories for train_fraction={train_fraction}"
            )

        train_arrays = merge_trajectories(train_trajs)
        valid_arrays = merge_trajectories(valid_trajs)

        # The first merged array holds the observations; its length is the
        # number of transitions in that split.
        train_dataset = Dataset(*train_arrays, size=len(train_arrays[0]))
        valid_dataset = Dataset(*valid_arrays, size=len(valid_arrays[0]))

        return train_dataset, valid_dataset
